#!/usr/bin/env python

import math

import numpy as np
import scipy.sparse as sp


def cart2sph(x, y, z):
    """
    Convert Cartesian components to spherical coordinates.

    See https://en.wikipedia.org/wiki/Spherical_coordinate_system#Cartesian_coordinates

    Parameters
    ----------
    x, y, z
        Cartesian components; anything supporting ``**`` and the NumPy
        ufuncs used below (scalars or arrays).

    Returns
    -------
    tuple
        ``(r, theta, phi)`` with ``r = sqrt(x² + y² + z²)``,
        ``theta = arctan2(y, x)`` and ``phi = arctan2(sqrt(x² + y²), z)``.
        Both angles are in radians.
    """
    sq_x = x**2
    sq_y = y**2
    sq_z = z**2
    # squared length of the projection onto the x-y plane, shared by r and phi
    planar_sq = sq_x + sq_y

    r = np.sqrt(planar_sq + sq_z)
    theta = np.arctan2(y, x)
    phi = np.arctan2(np.sqrt(planar_sq), z)
    return r, theta, phi


def calcangle(x1, x2):
    """
    Return the angle between two vectors in degrees.

    Parameters
    ----------
    x1, x2
        Sequences or arrays of equal length holding the vector components.

    Returns
    -------
    float
        Angle in the range [0, 180]. NaN when the cosine is undefined
        (e.g. a zero-length vector yields 0/0).
    """
    cosine = np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))
    # Floating-point rounding can push the cosine of (anti-)parallel vectors
    # marginally outside [-1, 1], which makes math.acos raise ValueError.
    # np.clip leaves NaN untouched, so undefined input still yields NaN.
    cosine = np.clip(cosine, -1.0, 1.0)
    return math.degrees(math.acos(cosine))


def gen_scanpath_structure(data):
    """
    Turn a table of fixations into the fixation/saccade dict used downstream.

    Parameters
    ----------
    data
        Object indexable by the field names ``"start_x"``, ``"start_y"``,
        ``"start_z"`` and ``"duration"`` and sliceable with ``data[:-1]``
        (presumably a NumPy structured array -- confirm with the caller).

    Returns
    -------
    dict
        ``{"fix": {...}, "sac": {...}}``. ``fix`` holds the fixation
        coordinates and durations. ``sac`` holds, per saccade, its start
        coordinates, its Cartesian components (``lenx``, ``leny``,
        ``lenz``) and its spherical form (``rho``, ``theta``, ``phi``).
    """
    fixations = {
        "x": data["start_x"],
        "y": data["start_y"],
        "z": data["start_z"],
        "dur": data["duration"],
    }

    # a saccade is the displacement between two consecutive fixations
    step_x = np.diff(data["start_x"])
    step_y = np.diff(data["start_y"])
    step_z = np.diff(data["start_z"])
    # in 3D the amplitude and direction are expressed spherically
    rho, theta, phi = cart2sph(step_x, step_y, step_z)

    # every fixation but the last one is the origin of a saccade
    origins = data[:-1]
    saccades = {
        "x": origins["start_x"],
        "y": origins["start_y"],
        "z": origins["start_z"],
        "lenx": step_x,
        "leny": step_y,
        "lenz": step_z,
        "theta": theta,
        "rho": rho,
        "phi": phi,
    }
    return {"fix": fixations, "sac": saccades}


def keepsaccade(i, j, sim, data):
    """
    Copy saccade ``i`` of ``data`` unchanged into slot ``j`` of ``sim``.

    All nine saccade fields plus the fixation duration at index ``i`` are
    inserted at position ``j`` of the corresponding lists in ``sim``
    (``sim`` is modified in place).

    Parameters
    ----------
    i : int
        Read index into ``data``.
    j : int
        Insert position in the lists of ``sim``.
    sim : dict
        Scanpath under construction; its fields must be lists.
    data : dict
        Source scanpath.

    Returns
    -------
    tuple
        ``(i + 1, j + 1)``, the advanced read and write positions.
    """
    saccade_fields = ("lenx", "leny", "lenz", "x", "y", "z", "theta", "rho", "phi")
    for field in saccade_fields:
        sim["sac"][field].insert(j, data["sac"][field][i])
    # the fixation preceding the saccade travels with it
    sim["fix"]["dur"].insert(j, data["fix"]["dur"][i])

    return i + 1, j + 1


def _get_empty_path():
    """
    Return a blank scanpath skeleton.

    The result has a ``"fix"`` dict with an empty ``dur`` list and a
    ``"sac"`` dict with one empty list per saccade field. Every call
    creates new list objects, so results never share state.
    """
    saccade_fields = ("x", "y", "z", "lenx", "leny", "lenz", "theta", "rho", "phi")
    return {
        "fix": {"dur": []},
        "sac": {field: [] for field in saccade_fields},
    }


def simlen(path, TAmp, TDur):
    """
    Amplitude-based simplification: fuse short saccades with a neighbour.

    A saccade is fused when its amplitude ``rho`` is below ``TAmp`` and at
    least one of the two fixation durations inspected for it is below
    ``TDur``. A non-final saccade is summed with the saccade that follows
    it. The final saccade is folded into the entry written last to the
    output. Every other saccade is copied unchanged via ``keepsaccade``.

    Parameters
    ----------
    path : dict
        Scanpath with ``"sac"`` and ``"fix"`` sub-dicts.
    TAmp
        Amplitude threshold, compared against ``rho``.
    TDur
        Fixation-duration threshold, compared against ``dur``.

    Returns
    -------
    dict
        A newly built simplified scanpath. ``path`` itself is returned
        when it contains no saccades.
    """
    # shortcuts
    saccades = path["sac"]
    fixations = path["fix"]

    # the input is only read inside the loop, so its length is constant
    n_sac = len(saccades["x"])
    if n_sac < 1:
        return path

    i = 0  # read position in the input scanpath
    j = 0  # write position in the simplified scanpath
    sim = _get_empty_path()
    while i < n_sac:
        is_short = saccades["rho"][i] < TAmp
        if i == n_sac - 1:
            # Final saccade: it can only be folded into a predecessor that
            # has already been written. With a single saccade there is none
            # (i == 0); that case used to raise IndexError and is now kept
            # unchanged instead.
            if (
                i > 0
                and is_short
                and (fixations["dur"][-1] < TDur or fixations["dur"][-2] < TDur)
            ):
                # NOTE(review): the sum uses the last two *input* vectors,
                # not the entry currently stored at sim[...][j - 1]; if that
                # entry is itself a merge its earlier part is overwritten --
                # confirm this is intended.
                v_x = saccades["lenx"][-2] + saccades["lenx"][-1]
                v_y = saccades["leny"][-2] + saccades["leny"][-1]
                v_z = saccades["lenz"][-2] + saccades["lenz"][-1]
                rho, theta, phi = cart2sph(v_x, v_y, v_z)
                # overwrite the previously written saccade with the sum
                sim["sac"]["lenx"][j - 1] = v_x
                sim["sac"]["leny"][j - 1] = v_y
                sim["sac"]["lenz"][j - 1] = v_z
                sim["sac"]["theta"][j - 1] = theta
                sim["sac"]["rho"][j - 1] = rho
                sim["sac"]["phi"][j - 1] = phi
                # NOTE(review): this adds a duration without adding a
                # saccade, so together with the final append below the
                # result holds more durations than saccades + 1 -- confirm.
                sim["fix"]["dur"].insert(j, fixations["dur"][i - 1])
                j -= 1
                i += 1
                continue
        elif is_short and (
            fixations["dur"][i + 1] < TDur or fixations["dur"][i] < TDur
        ):
            # fuse this saccade with the next one: sum the local vectors
            v_x = saccades["lenx"][i] + saccades["lenx"][i + 1]
            v_y = saccades["leny"][i] + saccades["leny"][i + 1]
            v_z = saccades["lenz"][i] + saccades["lenz"][i + 1]
            rho, theta, phi = cart2sph(v_x, v_y, v_z)
            sim["sac"]["lenx"].insert(j, v_x)
            sim["sac"]["leny"].insert(j, v_y)
            sim["sac"]["lenz"].insert(j, v_z)
            # the fused saccade starts where the first one started
            sim["sac"]["x"].insert(j, saccades["x"][i])
            sim["sac"]["y"].insert(j, saccades["y"][i])
            sim["sac"]["z"].insert(j, saccades["z"][i])
            sim["sac"]["theta"].insert(j, theta)
            sim["sac"]["rho"].insert(j, rho)
            sim["sac"]["phi"].insert(j, phi)
            # keep the duration of the fixation preceding the fused saccade
            sim["fix"]["dur"].insert(j, fixations["dur"][i])
            i += 2
            j += 1
            continue
        # no simplification: copy the original saccade
        i, j = keepsaccade(i, j, sim, path)
    # append the last fixation duration
    sim["fix"]["dur"].append(fixations["dur"][-1])

    return sim


def simdir(path, TDir, TDur):
    """
    Direction-based simplification: fuse consecutive, similarly directed saccades.

    Two neighbouring saccades are summed into one when the angle between
    them (as returned by ``calcangle``) is below ``TDir`` and the fixation
    between them lasts less than ``TDur``. All other saccades are copied
    unchanged via ``keepsaccade``.

    Parameters
    ----------
    path : dict
        Scanpath with ``"sac"`` and ``"fix"`` sub-dicts.
    TDir
        Angular threshold, compared against the ``calcangle`` result.
    TDur
        Fixation-duration threshold, compared against ``dur``.

    Returns
    -------
    dict
        A newly built simplified scanpath. ``path`` itself is returned
        when it contains no saccades.
    """
    saccades = path["sac"]
    fixations = path["fix"]

    # the input is only read below, so its length can be taken once
    total = len(saccades["x"])
    if total < 1:
        return path

    sim = _get_empty_path()
    src = 0  # read position in the input scanpath
    dst = 0  # write position in the simplified scanpath
    while src < total:
        has_next = src < total - 1
        if has_next:
            current = [
                saccades["lenx"][src],
                saccades["leny"][src],
                saccades["lenz"][src],
            ]
            following = [
                saccades["lenx"][src + 1],
                saccades["leny"][src + 1],
                saccades["lenz"][src + 1],
            ]
            angle = calcangle(current, following)
        else:
            # the last saccade has no successor; infinity never passes the
            # threshold test below
            angle = float("inf")

        if (angle < TDir) and has_next and (fixations["dur"][src + 1] < TDur):
            # sum the two local vectors component-wise
            v_x = current[0] + following[0]
            v_y = current[1] + following[1]
            v_z = current[2] + following[2]
            rho, theta, phi = cart2sph(v_x, v_y, v_z)
            fused = (
                ("lenx", v_x),
                ("leny", v_y),
                ("lenz", v_z),
                # the fused saccade starts where the first one started
                ("x", saccades["x"][src]),
                ("y", saccades["y"][src]),
                ("z", saccades["z"][src]),
                ("theta", theta),
                ("rho", rho),
                ("phi", phi),
            )
            for field, value in fused:
                sim["sac"][field].insert(dst, value)
            # keep the duration of the fixation preceding the fused saccade
            sim["fix"]["dur"].insert(dst, fixations["dur"][src])
            src += 2
            dst += 1
        else:
            # no simplification: copy the original saccade
            src, dst = keepsaccade(src, dst, sim, path)

    # the final fixation has no outgoing saccade; add its duration last
    sim["fix"]["dur"].append(fixations["dur"][-1])

    return sim


def simplify_scanpath(path, TAmp, TDir, TDur):
    """
    Simplify a scanpath until the number of fixation durations stops changing.

    Each round applies ``simdir`` followed by ``simlen``. The loop ends as
    soon as a round leaves ``len(path["fix"]["dur"])`` unchanged.

    Parameters
    ----------
    path : dict
        Scanpath with ``"sac"`` and ``"fix"`` sub-dicts.
    TAmp
        Amplitude threshold forwarded to ``simlen``.
    TDir
        Angular threshold forwarded to ``simdir``.
    TDur
        Duration threshold forwarded to both passes.

    Returns
    -------
    dict
        The simplified scanpath.
    """
    n_durations = len(path["fix"]["dur"])
    converged = False
    while not converged:
        path = simlen(simdir(path, TDir, TDur), TAmp, TDur)
        previous = n_durations
        n_durations = len(path["fix"]["dur"])
        converged = n_durations == previous
    return path


def cal_vectordifferences(path1, path2):
    """
    Build the matrix of pairwise saccade-vector differences.

    Parameters
    ----------
    path1, path2 : dict
        Scanpaths whose ``"sac"`` dict holds the saccade components
        ``lenx``, ``leny`` and ``lenz``.

    Returns
    -------
    numpy.ndarray
        Float matrix ``M`` of shape ``(n1, n2)`` where ``M[i, j]`` is the
        Euclidean norm of the difference between saccade vector ``i`` of
        ``path1`` and saccade vector ``j`` of ``path2``. An empty ``path1``
        yields an empty ``(0, n2)`` matrix.
    """
    squared = 0.0
    # accumulate x, y, z in this order so rounding matches a plain
    # x² + y² + z² sum
    for axis in ("lenx", "leny", "lenz"):
        first = np.asarray(path1["sac"][axis], dtype=float)
        second = np.asarray(path2["sac"][axis], dtype=float)
        # broadcasting a column against a row gives all pairings at once
        squared = squared + (first[:, np.newaxis] - second[np.newaxis, :]) ** 2
    M = np.sqrt(squared)
    return M


def createdirectedgraph(scanpath_dim, M, M_assignment):
    """
    List the edges of the alignment graph laid over the difference matrix.

    Each matrix cell ``(i, j)`` is a node numbered row by row. From a node
    one may step right, down, or diagonally down-right as far as the matrix
    allows. The weight of an edge is the matrix value of the cell it leads
    to. The bottom-right node gets a zero-weight edge onto itself.

    Parameters
    ----------
    scanpath_dim : sequence
        ``(n_rows, n_cols)`` of ``M``.
    M
        2-D array indexable as ``M[i, j]``.
    M_assignment
        Not used by this function.

    Returns
    -------
    tuple
        ``(numVert, rows, cols, weight)``: the node count and three
        parallel NumPy arrays with the source node, target node and weight
        of every edge.
    """
    n_rows = scanpath_dim[0]
    n_cols = scanpath_dim[1]

    rows = []
    cols = []
    weight = []

    for i in range(n_rows):
        can_down = i < n_rows - 1
        for j in range(n_cols):
            can_right = j < n_cols - 1
            node = i * n_cols + j

            if can_right and can_down:
                # interior node: right, down and down-right
                steps = (
                    (node + 1, M[i, j + 1]),
                    (node + n_cols, M[i + 1, j]),
                    (node + n_cols + 1, M[i + 1, j + 1]),
                )
            elif can_right:
                # bottom row: only rightwards
                steps = ((node + 1, M[i, j + 1]),)
            elif can_down:
                # rightmost column: only downwards
                steps = ((node + n_cols, M[i + 1, j]),)
            else:
                # bottom-right corner: end of every path
                steps = ((node, 0),)

            for target, cost in steps:
                rows.append(node)
                cols.append(target)
                weight.append(cost)

    numVert = n_rows * n_cols
    return numVert, np.asarray(rows), np.asarray(cols), np.asarray(weight)


def dijkstra(numVert, rows, cols, data, start, end):
    """
    Find the cheapest path from ``start`` to ``end`` in a directed graph.

    Parameters
    ----------
    numVert : int
        Number of nodes in the graph.
    rows, cols, data
        Parallel sequences giving source node, target node and weight of
        every edge.
    start : int
        Node the search starts from.
    end : int
        Node the path must end at.

    Returns
    -------
    tuple
        ``(path, dist)``: the node indices from ``start`` to ``end`` as a
        list of ints, and the summed weight as a float. If ``end`` cannot
        be reached the path is ``[end]`` and the distance is ``inf``.
    """
    # A sparse CSR matrix keeps memory low: only existing edges are stored.
    graph = sp.coo_matrix((data, (rows, cols)), shape=(numVert, numVert)).tocsr()

    # Search from the requested start node (this used to be hard-coded to
    # node 0, silently ignoring ``start``).
    dist_matrix, predecessors = sp.csgraph.dijkstra(
        csgraph=graph, directed=True, indices=start, return_predecessors=True
    )
    dist = float(dist_matrix[end])

    # Walk the predecessor chain backwards from the end node; scipy marks
    # "no predecessor" (the start node, or an unreachable node) with -9999.
    path = [int(end)]
    node = predecessors[end]
    while node != -9999:
        path.append(int(node))
        node = predecessors[node]

    # the chain was collected end -> start
    path.reverse()
    return path, dist


def cal_angulardifference(data1, data2, path, M_assignment):
    """
    Angular difference of the saccade pairs along an alignment path.

    For every node on ``path`` the wrapped absolute differences of the
    ``theta`` angles and of the ``phi`` angles are added up, and the sum is
    wrapped once more if it exceeds pi.

    Parameters
    ----------
    data1, data2 : dict
        Scanpaths providing ``["sac"]["theta"]`` and ``["sac"]["phi"]``.
    path
        Iterable of node numbers.
    M_assignment
        2-D NumPy array in which each node number occurs exactly once; its
        row/column give the saccade index in ``data1``/``data2``.

    Returns
    -------
    list
        One angular difference (radians) per node on ``path``.
    """
    theta1 = data1["sac"]["theta"]
    theta2 = data2["sac"]["theta"]
    phi1 = data1["sac"]["phi"]
    phi2 = data2["sac"]["phi"]

    def _wrapped_gap(first, second):
        # shift negative angles up by a full turn before comparing
        if first < 0:
            first = math.pi + (math.pi + first)
        if second < 0:
            second = math.pi + (math.pi + second)
        gap = abs(first - second)
        # take the shorter way around the circle
        if gap > math.pi:
            gap = 2 * math.pi - gap
        return gap

    anglediff = []
    for node in path:
        # locate the pair of saccades this node stands for
        row, col = np.where(M_assignment == node)
        r = row.item()
        c = col.item()
        combined = _wrapped_gap(theta1[r], theta2[c]) + _wrapped_gap(phi1[r], phi2[c])
        if combined > math.pi:
            combined = 2 * math.pi - combined
        anglediff.append(combined)
    return anglediff


def cal_durationdifference(data1, data2, path, M_assignment):
    """
    Relative fixation-duration difference along an alignment path.

    Parameters
    ----------
    data1, data2 : dict
        Scanpaths providing ``["fix"]["dur"]``.
    path
        Iterable of node numbers.
    M_assignment
        2-D NumPy array in which each node number occurs exactly once; its
        row/column give the index into the durations of ``data1``/``data2``.

    Returns
    -------
    list
        Per node, the absolute duration difference divided by the larger
        of the two durations.
        NOTE(review): a pair whose larger duration is 0 divides by zero.
    """
    dur1 = data1["fix"]["dur"]
    dur2 = data2["fix"]["dur"]

    durdiff = []
    for node in path:
        # locate the pair of fixations this node stands for
        row, col = np.where(M_assignment == node)
        first = dur1[row.item()]
        second = dur2[col.item()]
        longest = max([first, second])
        durdiff.append(abs(first - second) / abs(longest))
    return durdiff


def cal_lengthdifference(data1, data2, path, M_assignment):
    """
    Absolute saccade-amplitude difference along an alignment path.

    Parameters
    ----------
    data1, data2 : dict
        Scanpaths providing ``["sac"]["rho"]``.
    path
        Iterable of node numbers.
    M_assignment
        2-D NumPy array; the position(s) of a node number give the saccade
        index in ``data1`` (row) and ``data2`` (column).

    Returns
    -------
    list
        One NumPy array per node holding ``|rho1 - rho2|``. Because the
        index arrays from ``np.where`` are used directly, each entry is an
        array (length 1 when the node number is unique), not a scalar.
    """
    rho_first = np.asarray(data1["sac"]["rho"])
    rho_second = np.asarray(data2["sac"]["rho"])

    lendiff = []
    for node in path:
        row_idx, col_idx = np.where(M_assignment == node)
        gap = rho_first[row_idx] - rho_second[col_idx]
        lendiff.append(abs(gap))
    return lendiff


def cal_positiondifference(data1, data2, path, M_assignment):
    """
    Distance between saccade start points along an alignment path.

    Parameters
    ----------
    data1, data2 : dict
        Scanpaths providing ``["sac"]["x"]``, ``["sac"]["y"]`` and
        ``["sac"]["z"]``.
    path
        Iterable of node numbers.
    M_assignment
        2-D NumPy array in which each node number occurs exactly once; its
        row/column give the saccade index in ``data1``/``data2``.

    Returns
    -------
    list
        Per node, the Euclidean distance between the two start points.
    """
    axes = ("x", "y", "z")
    starts1 = [np.asarray(data1["sac"][axis]) for axis in axes]
    starts2 = [np.asarray(data2["sac"][axis]) for axis in axes]

    posdiff = []
    for node in path:
        # locate the pair of saccades this node stands for
        row, col = np.where(M_assignment == node)
        r = row.item()
        c = col.item()
        sq_x, sq_y, sq_z = (
            (first[r] - second[c]) ** 2 for first, second in zip(starts1, starts2)
        )
        posdiff.append(math.sqrt(sq_x + sq_y + sq_z))
    return posdiff


def cal_vectordifferencealongpath(data1, data2, path, M_assignment):
    """
    Saccade-vector difference along an alignment path.

    Parameters
    ----------
    data1, data2 : dict
        Scanpaths providing the saccade components ``["sac"]["lenx"]``,
        ``["sac"]["leny"]`` and ``["sac"]["lenz"]``.
    path
        Iterable of node numbers.
    M_assignment
        2-D NumPy array in which each node number occurs exactly once; its
        row/column give the saccade index in ``data1``/``data2``.

    Returns
    -------
    list
        Per node, the Euclidean norm of the difference of the two saccade
        vectors.
    """
    sac1 = data1["sac"]
    sac2 = data2["sac"]
    dx1, dx2 = np.asarray(sac1["lenx"]), np.asarray(sac2["lenx"])
    dy1, dy2 = np.asarray(sac1["leny"]), np.asarray(sac2["leny"])
    dz1, dz2 = np.asarray(sac1["lenz"]), np.asarray(sac2["lenz"])

    vectordiff = []
    for node in path:
        # locate the pair of saccades this node stands for
        row, col = np.where(M_assignment == node)
        r = row.item()
        c = col.item()
        gap_x = dx1[r] - dx2[c]
        gap_y = dy1[r] - dy2[c]
        gap_z = dz1[r] - dz2[c]
        vectordiff.append(np.sqrt(gap_x**2 + gap_y**2 + gap_z**2))
    return vectordiff


def getunnormalised(data1, data2, path, M_assignment):
    """
    Average each per-pair difference measure over the alignment path.

    Returns
    -------
    list
        Five means in this order: vector, angular, length, position and
        duration difference.
    """
    measures = (
        cal_vectordifferencealongpath,
        cal_angulardifference,
        cal_lengthdifference,
        cal_positiondifference,
        cal_durationdifference,
    )
    unnormalised = []
    for measure in measures:
        per_pair = measure(data1, data2, path, M_assignment)
        unnormalised.append(np.mean(per_pair))
    return unnormalised


def normaliseresults(unnormalised, screensize):
    """
    Convert mean differences into similarity scores.

    Parameters
    ----------
    unnormalised : sequence
        Mean vector, angular, length, position and duration difference,
        in that order.
    screensize : sequence
        Three extents; their Euclidean norm (the diagonal) is the
        normalisation length.

    Returns
    -------
    list
        ``[vector, direction, length, position, duration]`` similarity,
        each computed as ``1 - difference / normaliser``.
    """
    diagonal = math.sqrt(
        screensize[0] ** 2 + screensize[1] ** 2 + screensize[2] ** 2
    )
    # one normaliser per measure, in the order of ``unnormalised``
    vector_sim = 1 - unnormalised[0] / (2 * diagonal)  # twice the diagonal
    direction_sim = 1 - unnormalised[1] / math.pi  # pi
    length_sim = 1 - unnormalised[2] / diagonal  # the diagonal
    position_sim = 1 - unnormalised[3] / diagonal  # the diagonal
    duration_sim = 1 - unnormalised[4]  # used as is, no division here
    return [vector_sim, direction_sim, length_sim, position_sim, duration_sim]


def docomparison(
    fixation_vectors1,
    fixation_vectors2,
    screensize=(32, 32, 32),
    grouping=False,  # by default we don't use this.
    TDir=0.0,
    TDur=0.0,
    TAmp=0.0,
):
    """
    Compare two scanpaths and return five similarity scores.

    Parameters
    ----------
    fixation_vectors1, fixation_vectors2
        Fixation tables accepted by ``gen_scanpath_structure``; each needs
        at least three entries.
    screensize : sequence of three numbers
        Extents forwarded to ``normaliseresults``. The default is an
        immutable tuple so it cannot be altered between calls.
    grouping : bool
        When true, both scanpaths are simplified with
        ``simplify_scanpath`` before being compared.
    TDir, TDur, TAmp
        Direction, duration and amplitude thresholds for the
        simplification; only used when ``grouping`` is true.

    Returns
    -------
    list or numpy.ndarray
        The list returned by ``normaliseresults`` (vector, direction,
        length, position, duration similarity), or an array of five NaN
        when either input has fewer than three fixations.
    """
    # too short to compare: report NaN for every measure
    if len(fixation_vectors1) < 3 or len(fixation_vectors2) < 3:
        return np.repeat(np.nan, 5)

    # get the data into a geometric representation
    path1 = gen_scanpath_structure(fixation_vectors1)
    path2 = gen_scanpath_structure(fixation_vectors2)
    if grouping:
        # simplify the data
        path1 = simplify_scanpath(path1, TAmp, TDir, TDur)
        path2 = simplify_scanpath(path2, TAmp, TDir, TDur)

    # M holds the vector difference (edge weight) of every saccade pairing
    M = cal_vectordifferences(path1, path2)

    # number the cells of M row by row; each number is a graph node
    scanpath_dim = np.shape(M)
    n_nodes = scanpath_dim[0] * scanpath_dim[1]
    M_assignment = np.arange(n_nodes).reshape(scanpath_dim[0], scanpath_dim[1])

    # weighted graph of all allowed steps between nodes
    numVert, rows, cols, weight = createdirectedgraph(scanpath_dim, M, M_assignment)
    # cheapest route from the top-left to the bottom-right node; the total
    # cost is not needed here
    path, _dist = dijkstra(numVert, rows, cols, weight, 0, n_nodes - 1)

    # compute similarities on the aligned scanpaths and normalise them
    unnormalised = getunnormalised(path1, path2, path, M_assignment)
    return normaliseresults(unnormalised, screensize)
